---
title: CycleGAN training loop
keywords: fastai
sidebar: home_sidebar
summary: "Defines the loss and training loop functions/classes for CycleGAN."
description: "Defines the loss and training loop functions/classes for CycleGAN."
---
{% raw %}
{% endraw %} {% raw %}
{% endraw %}

The Loss Function

Let's start out by writing the loss function for the CycleGAN model. The main loss is the one used to train the generators, and it has three parts:

  • the classic GAN loss: the generators must make the discriminators believe their images are real.
  • identity loss: if a generator is given an image from the domain it is trying to imitate, it should return the same image.
  • cycle loss: if an image from domain A goes through the generator that imitates domain B and then through the generator that imitates domain A, it should be reconstructed as the initial image. The same holds for domain B, with the generators switched.
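As a standalone sketch in plain PyTorch, the three terms might be combined as below. The names `gen_a2b`/`gen_b2a` and `disc_A`/`disc_B` are hypothetical stand-ins for the CycleGAN sub-networks, and `l_A`, `l_B`, `l_idt` mirror the lambda weights documented below:

```python
import torch
import torch.nn.functional as F

# Sketch of the three generator-loss terms (hypothetical helper, not the
# library's CycleGANLoss). Uses the LSGAN (MSE) adversarial objective.
def generator_loss_sketch(real_A, real_B, gen_a2b, gen_b2a, disc_A, disc_B,
                          l_A=10.0, l_B=10.0, l_idt=0.5):
    fake_B = gen_a2b(real_A)   # A -> B translation
    fake_A = gen_b2a(real_B)   # B -> A translation
    # 1. adversarial loss: fool the discriminators into predicting "real" (1)
    pred_B, pred_A = disc_B(fake_B), disc_A(fake_A)
    gen_loss = (F.mse_loss(pred_B, torch.ones_like(pred_B))
                + F.mse_loss(pred_A, torch.ones_like(pred_A)))
    # 2. identity loss: a generator fed an image already in its output domain
    #    should return it unchanged
    idt_loss = (l_B * l_idt * F.l1_loss(gen_a2b(real_B), real_B)
                + l_A * l_idt * F.l1_loss(gen_b2a(real_A), real_A))
    # 3. cycle-consistency loss: A -> B -> A should reconstruct the original image
    cyc_loss = (l_A * F.l1_loss(gen_b2a(fake_B), real_A)
                + l_B * F.l1_loss(gen_a2b(fake_A), real_B))
    return gen_loss + idt_loss + cyc_loss
```

With identity generators and discriminators that already output 1, every term vanishes, which is a quick sanity check of the signs and weights.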
{% raw %}
{% endraw %} {% raw %}

class CycleGANLoss[source]

CycleGANLoss(cgan:Module, l_A:float=10.0, l_B:float=10, l_idt:float=0.5, lsgan:bool=True) :: Module

CycleGAN loss function. The individual loss terms are also attributes of this class that are accessed by fastai for recording during training.

Attributes:

self.cgan (nn.Module): The CycleGAN model.

self.l_A (float): lambda_A, weight of domain A losses.

self.l_B (float): lambda_B, weight of domain B losses.

self.l_idt (float): lambda_idt, weight of identity losses.

self.crit (AdaptiveLoss): The adversarial loss function (either a BCE or MSE loss depending on the lsgan argument).

self.real_A and self.real_B (fastai.torch_core.TensorImage): Real images from domain A and B.

self.id_loss_A (torch.FloatTensor): The identity loss for domain A calculated in the forward function.

self.id_loss_B (torch.FloatTensor): The identity loss for domain B calculated in the forward function.

self.gen_loss (torch.FloatTensor): The generator loss calculated in the forward function.

self.cyc_loss (torch.FloatTensor): The cyclic loss calculated in the forward function.

{% endraw %} {% raw %}

CycleGANLoss.__init__[source]

CycleGANLoss.__init__(cgan:Module, l_A:float=10.0, l_B:float=10, l_idt:float=0.5, lsgan:bool=True)

Constructor for CycleGAN loss.

Arguments:

cgan (nn.Module): The CycleGAN model.

l_A (float): weight of domain A losses. (default=10)

l_B (float): weight of domain B losses. (default=10)

l_idt (float): weight of identity losses. (default=0.5)

lsgan (bool): Whether or not to use LSGAN objective. (default=True)

{% endraw %} {% raw %}

CycleGANLoss.set_input[source]

CycleGANLoss.set_input(input)

Set self.real_A and self.real_B for future loss calculation.

{% endraw %} {% raw %}

CycleGANLoss.forward[source]

CycleGANLoss.forward(output, target)

Forward function of the CycleGAN loss function. The generated images are passed in as output (which comes from the model) and the generator loss is returned.

{% endraw %}

Training loop callback

Let's now write the main callback to train a CycleGAN model.

Fastai's callback system is very flexible, allowing us to adjust the traditional training loop in any conceivable way. Let's use it for GAN training.

We have the _set_trainable function, which is called with arguments specifying which networks should be put in training mode and which should be frozen.

When we start training (before_train), we define separate optimizers: self.opt_G for the generators and self.opt_D for the discriminators. Then we put the generators in training mode (with _set_trainable).

Before passing the batch into the model (before_batch), we have to adjust it: the domain B image was kept as the target, but it also needs to be passed into the model. We also set the inputs for the loss function.

In after_batch, we calculate the discriminator losses, backpropagate, and update the weights of both the discriminators. The main training loop will train the generators.
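For one discriminator, the update performed at this stage might look roughly like the following plain-PyTorch sketch (hypothetical helper name; the actual callback does this for both discriminators):

```python
import torch
import torch.nn.functional as F

# Sketch of a single discriminator update with the LSGAN objective:
# real images should score 1, generated images 0, averaged as in the
# original CycleGAN paper.
def discriminator_step_sketch(disc, opt, real, fake):
    opt.zero_grad()
    pred_real = disc(real)
    pred_fake = disc(fake.detach())  # detach: don't backprop into the generator
    loss = 0.5 * (F.mse_loss(pred_real, torch.ones_like(pred_real))
                  + F.mse_loss(pred_fake, torch.zeros_like(pred_fake)))
    loss.backward()
    opt.step()
    return loss.item()
```

Note the `detach()` on the generated images: the generators were already updated by the main training loop, so their gradients must not flow through the discriminator step.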

{% raw %}
{% endraw %} {% raw %}

class CycleGANTrainer[source]

CycleGANTrainer() :: Callback

Learner Callback for training a CycleGAN model.

{% endraw %} {% raw %}

CycleGANTrainer._set_trainable[source]

CycleGANTrainer._set_trainable(disc=False)

Put the generators or discriminators in training mode depending on arguments.
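As an illustration, freezing or unfreezing a set of networks typically amounts to toggling requires_grad on their parameters. A hypothetical helper in the same spirit:

```python
import torch.nn as nn

# Hypothetical helper illustrating the idea behind _set_trainable: freeze or
# unfreeze a collection of sub-networks by toggling requires_grad in place.
def set_requires_grad(nets, requires_grad):
    for net in nets:
        for p in net.parameters():
            p.requires_grad_(requires_grad)
```

During the discriminator step the generators would be frozen this way (and vice versa), so each optimizer only updates the networks it owns.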

{% endraw %} {% raw %}

CycleGANTrainer.after_batch[source]

CycleGANTrainer.after_batch(**kwargs)

Discriminator training loop

{% endraw %} {% raw %}
{% endraw %} {% raw %}

class ShowCycleGANImgsCallback[source]

ShowCycleGANImgsCallback(imgA:bool=False, imgB:bool=True, show_img_interval:int=10) :: Callback

Update the progress bar with input and prediction images

{% endraw %} {% raw %}

ShowCycleGANImgsCallback.__init__[source]

ShowCycleGANImgsCallback.__init__(imgA:bool=False, imgB:bool=True, show_img_interval:int=10)

If imgA is True, display B-to-A conversion example during training. If imgB is True, display A-to-B conversion example. Show images every show_img_interval epochs.

{% endraw %} {% raw %}

ShowCycleGANImgsCallback.after_epoch[source]

ShowCycleGANImgsCallback.after_epoch()

Update images

{% endraw %}

CycleGAN LR scheduler

The original CycleGAN paper used a period of constant learning rate followed by a period of linearly decaying learning rate. Let's make a scheduler to implement this (with other possibilities as well). Fastai already comes with many types of hyperparameter schedules, and new ones can be created by combining existing ones. Let's see how to do this:

{% raw %}
{% endraw %} {% raw %}

combined_flat_anneal[source]

combined_flat_anneal(pct:float, start_lr:float, end_lr:float=0, curve_type:str='linear')

Create a schedule with constant learning rate start_lr for pct proportion of the training, and a curve_type learning rate (till end_lr) for the remaining portion of training.

Arguments: pct (float): Proportion of training with a constant learning rate.

start_lr (float): Desired starting learning rate, used for the beginning pct of training.

end_lr (float): Desired end learning rate, training will conclude at this learning rate.

curve_type (str): Curve type for learning rate annealing. Options are 'linear', 'cosine', and 'exponential'.
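For reference, a standalone sketch of such a schedule factory (hypothetical name flat_anneal_sketch; the real combined_flat_anneal is built by combining fastai's existing schedules):

```python
import math

# Sketch: return a function mapping training progress t in [0, 1] to a learning
# rate that stays at start_lr for the first `pct` of training, then anneals to
# end_lr with the chosen curve over the remainder.
def flat_anneal_sketch(pct, start_lr, end_lr=0.0, curve_type='linear'):
    def _sched(t):
        if t <= pct:
            return start_lr
        frac = (t - pct) / (1 - pct)  # rescale annealing progress to [0, 1]
        if curve_type == 'linear':
            return start_lr + frac * (end_lr - start_lr)
        if curve_type == 'cosine':
            return end_lr + (start_lr - end_lr) * (1 + math.cos(math.pi * frac)) / 2
        if curve_type == 'exponential':
            return start_lr * (end_lr / start_lr) ** frac
        raise ValueError(f"unknown curve_type: {curve_type}")
    return _sched
```

For example, flat_anneal_sketch(0.5, 1, 1e-2) stays at 1 for the first half of training, then decays to 1e-2, matching the curves plotted below.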

{% endraw %} {% raw %}
import torch
import matplotlib.pyplot as plt

p = torch.linspace(0., 1, 200)
plt.plot(p, [combined_flat_anneal(0.5, 1, 1e-2, curve_type='linear')(o) for o in p], label='linear annealing')
plt.plot(p, [combined_flat_anneal(0.5, 1, 1e-2, curve_type='cosine')(o) for o in p], label='cosine annealing')
plt.plot(p, [combined_flat_anneal(0.5, 1, 1e-2, curve_type='exponential')(o) for o in p], label='exponential annealing')
plt.legend()
plt.title('Constant+annealing LR schedules')
Text(0.5,1,'Constant+annealing LR schedules')
{% endraw %}

Now that we have the learning rate schedule, we can write a quick training function that can be added as a method to Learner using the @patch decorator. The function is inspired by this code.

{% raw %}
{% endraw %} {% raw %}

Learner.fit_flat_lin[source]

Learner.fit_flat_lin(n_epochs:int=100, n_epochs_decay:int=100, start_lr:float=None, end_lr:float=0, curve_type:str='linear', wd:float=None, cbs=None, reset_opt=False)

Fit self.model for n_epochs at flat start_lr, then anneal with curve_type to end_lr over n_epochs_decay epochs, with weight decay wd and callbacks cbs.
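Presumably the two epoch counts map onto the pct argument of combined_flat_anneal: the flat phase covers n_epochs out of the total. A tiny illustration (hypothetical helper name):

```python
# Hypothetical helper: the fraction of training spent at the flat learning rate
# when fitting for n_epochs flat epochs plus n_epochs_decay annealing epochs.
def flat_pct(n_epochs, n_epochs_decay):
    return n_epochs / (n_epochs + n_epochs_decay)
```

With the defaults (100 flat + 100 decay epochs), half of training runs at the constant rate before annealing begins.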

{% endraw %} {% raw %}
from fastai.test_utils import *
{% endraw %} {% raw %}
learn = synth_learner()
learn.fit_flat_lin(n_epochs=2,n_epochs_decay=2)
epoch train_loss valid_loss time
0 12.987411 7.431311 00:00
1 11.110889 4.981249 00:00
2 9.148016 3.381751 00:00
3 7.668513 2.934142 00:00
{% endraw %} {% raw %}
learn.recorder.plot_sched()
{% endraw %}

CycleGAN Learner construction

Below, we define a function for initializing a Learner with the CycleGAN model and training callback.

{% raw %}
{% endraw %} {% raw %}

cycle_learner[source]

cycle_learner(dls:DataLoader, m:CycleGAN, opt_func='Adam', show_imgs:bool=True, imgA:bool=True, imgB:bool=True, show_img_interval:int=10, metrics:list=[], cbs:list=[], loss_func=None, lr=0.001, splitter='trainable_params', path=None, model_dir='models', wd=None, wd_bn_bias=False, train_bn=True, moms=(0.95, 0.85, 0.95))

Initialize and return a Learner object with the data in dls, CycleGAN model m, optimizer function opt_func, metrics metrics, and callbacks cbs. If show_imgs is True, intermediate predictions will be shown during training: domain B-to-A predictions if imgA is True and/or domain A-to-B predictions if imgB is True, every show_img_interval epochs. Other Learner arguments can be passed as well.

{% endraw %}

Quick Test

{% raw %}
horse2zebra = untar_data('https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/horse2zebra.zip')
{% endraw %} {% raw %}
folders = horse2zebra.ls().sorted()
{% endraw %} {% raw %}
trainA_path = folders[2]
trainB_path = folders[3]
testA_path = folders[0]
testB_path = folders[1]
{% endraw %} {% raw %}
dls = get_dls(trainA_path, trainB_path,num_A=100)
{% endraw %} {% raw %}
cycle_gan = CycleGAN(3,3,64)
learn = cycle_learner(dls, cycle_gan,show_img_interval=1)
{% endraw %} {% raw %}
learn.show_training_loop()
Start Fit
   - before_fit     : [TrainEvalCallback, Recorder, ProgressCallback, ShowCycleGANImgsCallback]
  Start Epoch Loop
     - before_epoch   : [Recorder, ProgressCallback]
    Start Train
       - before_train   : [TrainEvalCallback, CycleGANTrainer, Recorder, ProgressCallback]
      Start Batch Loop
         - before_batch   : [CycleGANTrainer]
         - after_pred     : []
         - after_loss     : []
         - before_backward: []
         - after_backward : []
         - after_step     : [CycleGANTrainer]
         - after_cancel_batch: []
         - after_batch    : [TrainEvalCallback, CycleGANTrainer, Recorder, ProgressCallback]
      End Batch Loop
    End Train
     - after_cancel_train: [Recorder]
     - after_train    : [Recorder, ProgressCallback]
    Start Valid
       - before_validate: [TrainEvalCallback, CycleGANTrainer, Recorder, ProgressCallback]
      Start Batch Loop
         - **CBs same as train batch**: []
      End Batch Loop
    End Valid
     - after_cancel_validate: [Recorder]
     - after_validate : [Recorder, ProgressCallback]
  End Epoch Loop
   - after_cancel_epoch: []
   - after_epoch    : [Recorder, ShowCycleGANImgsCallback]
End Fit
 - after_cancel_fit: []
 - after_fit      : [ProgressCallback]
{% endraw %} {% raw %}
test_eq(type(learn),Learner)
{% endraw %} {% raw %}
learn.lr_find()
/home/tmabraham/anaconda3/lib/python3.7/site-packages/fastprogress/fastprogress.py:74: UserWarning: Your generator is empty.
  warn("Your generator is empty.")
SuggestedLRs(lr_min=0.00036307806149125097, lr_steep=0.00010964782268274575)
{% endraw %} {% raw %}
learn.fit_flat_lin(5,5,2e-4)
epoch train_loss id_loss_A id_loss_B gen_loss_A gen_loss_B cyc_loss_A cyc_loss_B D_A_loss D_B_loss time
0 11.297443 1.644002 1.820261 0.682393 0.316483 3.379627 3.842381 0.433729 0.433729 00:17
1 9.672276 1.278216 1.287857 0.260192 0.273492 2.707866 2.862581 0.259786 0.259786 00:17
2 8.733617 1.176147 1.101748 0.308695 0.305979 2.483482 2.464781 0.245747 0.245747 00:17
3 8.317053 1.166875 1.125941 0.319255 0.313545 2.428372 2.462454 0.234355 0.234355 00:17
4 7.851682 1.081876 1.016431 0.318608 0.317330 2.283650 2.236785 0.234944 0.234944 00:17
5 7.634328 1.041350 1.080650 0.309708 0.310316 2.259288 2.341902 0.234397 0.234397 00:17
6 7.319435 0.986258 0.959184 0.311627 0.311700 2.150756 2.143827 0.228811 0.228811 00:17
7 6.956178 0.949250 0.880518 0.325909 0.329119 2.011847 1.919625 0.225709 0.225709 00:17
8 6.605078 0.952394 0.797397 0.325455 0.327534 1.978229 1.709552 0.223765 0.223765 00:17
9 6.335949 0.892926 0.826303 0.339374 0.314645 1.819247 1.767499 0.221000 0.221000 00:17
{% endraw %} {% raw %}
learn.recorder.plot_loss(with_valid=False)
{% endraw %}